import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(torch.nn.Module):
    """Feed-forward binary classifier over per-item message vectors.

    The network is ``Linear(msg_dim, model_dim)``, followed by
    ``num_ff_layers`` ``FeedForward`` blocks, followed by
    ``Linear(model_dim, 1)``. The single output is a logit scored with
    ``nn.BCEWithLogitsLoss``.

    Args:
        msg_dim: Size of the last dimension of ``batch.msg``.
        model_dim: Width of the hidden representation.
        num_ff_layers: Number of ``FeedForward`` blocks between the input
            and output projections (``0`` gives two stacked linear layers).
        expansion_factor: Hidden-width multiplier forwarded to each
            ``FeedForward`` block.
        dropout: Dropout probability forwarded to each ``FeedForward`` block.
        device: Device the layers are moved to at construction time. Stored
            on ``self.device`` as given; ``None`` leaves the layers where
            PyTorch created them.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            msg_dim,
            model_dim=200,
            num_ff_layers=1,
            expansion_factor=2,
            dropout=0.0,
            device=None,
            **kwargs
    ):
        super().__init__()

        self.msg_dim = msg_dim
        self.device = device
        self.dropout = dropout

        self.criterion = nn.BCEWithLogitsLoss()

        updates = [nn.Linear(msg_dim, model_dim)]
        updates.extend(
            FeedForward(model_dim, expansion_factor, dropout)
            for _ in range(num_ff_layers)
        )
        updates.append(nn.Linear(model_dim, 1))
        self.layers = nn.Sequential(*updates).to(device)

    def reset_parameters(self):
        """Re-draw every parameter with more than one dimension (Xavier normal).

        One-dimensional parameters (the biases of the linear layers) are
        left untouched.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def reset_state(self):
        """Do nothing; this model keeps no state between calls."""
        # NOTE(review): presumably kept for interface parity with stateful
        # models used by the same training loop -- confirm against callers.
        pass

    def detach(self):
        """Do nothing; this model keeps no state between calls."""
        # NOTE(review): presumably kept for interface parity with stateful
        # models used by the same training loop -- confirm against callers.
        pass

    def forward(self, batch):
        """Score a batch and compute the binary cross-entropy loss.

        Args:
            batch: Object exposing ``to(device)`` (whose return value is used
                as the batch), ``msg`` (tensor whose last dimension is
                ``msg_dim``) and ``y`` (labels; assumed to have one
                dimension fewer than the logits -- TODO confirm).

        Returns:
            Tuple ``(loss, y_pred, y_true)``: the loss (still attached to the
            graph), the detached CPU probabilities ``sigmoid(logits)``, and
            the detached CPU float labels with a trailing unit dimension.
        """
        # Follow the weights rather than the constructor argument: after a
        # later ``model.to(...)`` / ``.cuda()`` / ``.cpu()`` the stored
        # ``self.device`` no longer says where the parameters live.
        device = self.layers[0].weight.device
        batch = batch.to(device)
        logits = self.layers(batch.msg)
        # Trailing unit dimension so the labels line up with the 1-wide logits.
        y = batch.y.float().unsqueeze(-1)
        loss = self.criterion(logits, y)
        y_pred = logits.detach().cpu().sigmoid()
        y_true = y.detach().cpu()
        return loss, y_pred, y_true


class FeedForward(nn.Module):
    """
    Two-layer perceptron: expand, apply leaky ReLU, project back.

    ``linear_0`` maps ``dim`` to ``int(expansion_factor * dim)`` features and
    ``linear_1`` maps them back to ``dim``. Dropout with probability
    ``dropout`` follows the activation and the output projection; it is
    active only while the module is in training mode.
    """
    def __init__(self, dim, expansion_factor, dropout=0.0):
        super().__init__()
        self.dim, self.expansion_factor, self.dropout = (
            dim, expansion_factor, dropout,
        )
        hidden_dim = int(expansion_factor * dim)
        self.linear_0 = nn.Linear(dim, hidden_dim)
        self.linear_1 = nn.Linear(hidden_dim, dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers with their default scheme."""
        for layer in (self.linear_0, self.linear_1):
            layer.reset_parameters()

    def forward(self, x):
        """Return ``x`` passed through expand, activate, dropout, project, dropout."""
        hidden = F.dropout(
            F.leaky_relu(self.linear_0(x)),
            p=self.dropout,
            training=self.training,
        )
        return F.dropout(
            self.linear_1(hidden),
            p=self.dropout,
            training=self.training,
        )